{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import sys\n",
    "root = '/'.join(os.path.realpath('.').replace('\\\\','/').split('/'))\n",
    "p = root + '/CMMLU/src'\n",
    "if p not in sys.path:\n",
    "    sys.path.append(p)\n",
    "import argparse\n",
    "from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from transformers.generation.configuration_utils import GenerationConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "git clone -- depth 1 https://github.com/haonan-li/CMMLU.git\n",
    "```\n",
    "\n",
    "cpoied from https://github.com/haonan-li/CMMLU/blob/master/src/hf_causal_model.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = '../model_save/dpo'  # 模型文件在上一层目录，使用dpo后的模型\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# 加载模型\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)\n",
    "generation_config = GenerationConfig()\n",
    "generation_config.remove_invalid_values = True  # 自动添加InfNanRemoveLogitsProcessor\n",
    "generation_config.eos_token_id = tokenizer.eos_token_id\n",
    "generation_config.pad_token_id = tokenizer.pad_token_id\n",
    "# for t5, set decoder_start_token_id = pad_token_id\n",
    "generation_config.decoder_start_token_id = tokenizer.pad_token_id  \n",
    "generation_config.max_new_tokens = 1\n",
    "generation_config.num_beams = 1\n",
    "generation_config.do_sample = False   # greedy search\n",
    "\n",
    "choices = ['A', 'B', 'C', 'D']\n",
    "choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n",
    "choices_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):\n",
    "    choice_ids = [tokenizer.convert_tokens_to_ids(choice) for choice in choices]\n",
    "    cors = []\n",
    "    all_conf = []\n",
    "    all_preds = []\n",
    "    answers = choices[: test_df.shape[1] - 2]\n",
    "\n",
    "    for i in range(test_df.shape[0]):\n",
    "        prompt_end = format_example(test_df, i, subject, include_answer=False)\n",
    "        prompt = gen_prompt(dev_df=dev_df,\n",
    "                            subject=subject,\n",
    "                            prompt_end=prompt_end,\n",
    "                            num_few_shot=num_few_shot,\n",
    "                            tokenizer=tokenizer,\n",
    "                            max_length=max_length)\n",
    "        inputs = tokenizer([prompt])\n",
    "        if \"token_type_ids\" in inputs: # For Falcon\n",
    "            inputs.pop(\"token_type_ids\")\n",
    "        label = test_df.iloc[i, test_df.shape[1] - 1]\n",
    "        torch.cuda.empty_cache()\n",
    "        \n",
    "        input_ids, attention_mask = torch.LongTensor(inputs['input_ids']), torch.LongTensor(inputs['attention_mask'])\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            outputs = model.generate(\n",
    "            input_ids=input_ids.to(device),\n",
    "            attention_mask=attention_mask.to(device),\n",
    "            generation_config=generation_config,\n",
    "            return_dict_in_generate=True,\n",
    "            output_scores=True,\n",
    "        )\n",
    "            \n",
    "            scores = torch.stack(outputs['scores'], dim=1).to('cpu')\n",
    "            scores = torch.softmax(scores, dim=2)\n",
    "            scores = scores[...,  0, choices_ids]  #取第一个字符的ABCD概率\n",
    "            conf = scores[0][choices.index(label)]\n",
    "            choices_index = torch.argmax(scores)\n",
    "            \n",
    "            pred = choices[choices_index]\n",
    "\n",
    "        all_preds += pred\n",
    "        all_conf.append(conf)\n",
    "        cors.append(pred == label)\n",
    "\n",
    "    acc = np.mean(cors)\n",
    "    print(\"Average accuracy {:.3f} - {}\".format(acc, subject))\n",
    "    return acc, all_preds, conf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average accuracy 0.243 - agronomy\n",
      "Average accuracy 0.243 - anatomy\n",
      "Average accuracy 0.256 - ancient_chinese\n",
      "Average accuracy 0.256 - arts\n",
      "Average accuracy 0.248 - astronomy\n",
      "Average accuracy 0.234 - business_ethics\n",
      "Average accuracy 0.256 - chinese_civil_service_exam\n",
      "Average accuracy 0.260 - chinese_driving_rule\n",
      "Average accuracy 0.235 - chinese_food_culture\n",
      "Average accuracy 0.252 - chinese_foreign_policy\n",
      "Average accuracy 0.251 - chinese_history\n",
      "Average accuracy 0.250 - chinese_literature\n",
      "Average accuracy 0.246 - chinese_teacher_qualification\n",
      "Average accuracy 0.253 - clinical_knowledge\n",
      "Average accuracy 0.245 - college_actuarial_science\n",
      "Average accuracy 0.318 - college_education\n",
      "Average accuracy 0.302 - college_engineering_hydrology\n",
      "Average accuracy 0.213 - college_law\n",
      "Average accuracy 0.219 - college_mathematics\n",
      "Average accuracy 0.264 - college_medical_statistics\n",
      "Average accuracy 0.234 - college_medicine\n",
      "Average accuracy 0.240 - computer_science\n",
      "Average accuracy 0.263 - computer_security\n",
      "Average accuracy 0.252 - conceptual_physics\n",
      "Average accuracy 0.252 - construction_project_management\n",
      "Average accuracy 0.239 - economics\n",
      "Average accuracy 0.258 - education\n",
      "Average accuracy 0.250 - electrical_engineering\n",
      "Average accuracy 0.282 - elementary_chinese\n",
      "Average accuracy 0.242 - elementary_commonsense\n",
      "Average accuracy 0.282 - elementary_information_and_technology\n",
      "Average accuracy 0.283 - elementary_mathematics\n",
      "Average accuracy 0.252 - ethnology\n",
      "Average accuracy 0.252 - food_science\n",
      "Average accuracy 0.239 - genetics\n",
      "Average accuracy 0.242 - global_facts\n",
      "Average accuracy 0.272 - high_school_biology\n",
      "Average accuracy 0.235 - high_school_chemistry\n",
      "Average accuracy 0.271 - high_school_geography\n",
      "Average accuracy 0.250 - high_school_mathematics\n",
      "Average accuracy 0.255 - high_school_physics\n",
      "Average accuracy 0.252 - high_school_politics\n",
      "Average accuracy 0.254 - human_sexuality\n",
      "Average accuracy 0.249 - international_law\n",
      "Average accuracy 0.250 - journalism\n",
      "Average accuracy 0.253 - jurisprudence\n",
      "Average accuracy 0.252 - legal_and_moral_basis\n",
      "Average accuracy 0.252 - logical\n",
      "Average accuracy 0.238 - machine_learning\n",
      "Average accuracy 0.243 - management\n",
      "Average accuracy 0.250 - marketing\n",
      "Average accuracy 0.249 - marxist_theory\n",
      "Average accuracy 0.250 - modern_chinese\n",
      "Average accuracy 0.241 - nutrition\n",
      "Average accuracy 0.257 - philosophy\n",
      "Average accuracy 0.251 - professional_accounting\n",
      "Average accuracy 0.251 - professional_law\n",
      "Average accuracy 0.242 - professional_medicine\n",
      "Average accuracy 0.246 - professional_psychology\n",
      "Average accuracy 0.247 - public_relations\n",
      "Average accuracy 0.252 - security_study\n",
      "Average accuracy 0.252 - sociology\n",
      "Average accuracy 0.248 - sports_science\n",
      "Average accuracy 0.254 - traditional_chinese_medicine\n",
      "Average accuracy 0.243 - virology\n",
      "Average accuracy 0.242 - world_history\n",
      "Average accuracy 0.256 - world_religions\n",
      "STEM                                     25.16\n",
      "Humanities                               24.78\n",
      "Social Science                           25.42\n",
      "Other                                    25.15\n",
      "China specific                           25.26\n",
      "Overall                        25.17\n"
     ]
    }
   ],
   "source": [
    "from dataclasses import dataclass\n",
    "@dataclass\n",
    "class Args:\n",
    "    data_dir: str = './CMMLU/data'\n",
    "    save_dir: str = './result'\n",
    "    num_few_shot: int = 0\n",
    "    max_length: int = 512\n",
    "\n",
    "run_eval(model, tokenizer, eval, Args())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
